{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import paddle\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = paddle.get_device()\n",
    "device = 'gpu:1'\n",
    "paddle.set_device(device)\n",
    "torch.cuda.set_device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.2\n",
      "1.9.0+cu102\n"
     ]
    }
   ],
   "source": [
    "print(paddle.__version__)\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conv2D(3, 4, kernel_size=[1, 1], data_format=NCHW)\n",
      "Conv2D(3, 4, kernel_size=[1, 1], data_format=NCHW)\n"
     ]
    }
   ],
   "source": [
    "conv_paddle = paddle.nn.Conv2D(\n",
    "    in_channels=3,\n",
    "    out_channels=4,\n",
    "    kernel_size=1,\n",
    "    stride=1)\n",
    "conv_torch = paddle.nn.Conv2D(\n",
    "    in_channels=3,\n",
    "    out_channels=4,\n",
    "    kernel_size=1,\n",
    "    stride=1)\n",
    "print(conv_paddle)\n",
    "print(conv_torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=10, out_features=20, dtype=float32)\n",
      "Linear(in_features=10, out_features=20, bias=True)\n",
      "====linear_paddle info====\n",
      "weight [10, 20]\n",
      "bias [20]\n",
      "\n",
      "====linear_torch info====\n",
      "weight torch.Size([20, 10])\n",
      "bias torch.Size([20])\n"
     ]
    }
   ],
   "source": [
    "linear_paddle = paddle.nn.Linear(\n",
    "    in_features=10,\n",
    "    out_features=20)\n",
    "linear_torch = torch.nn.Linear(\n",
    "    in_features=10,\n",
    "    out_features=20)\n",
    "print(linear_paddle)\n",
    "print(linear_torch)\n",
    "print(\"====linear_paddle info====\")\n",
    "for name, weight in linear_paddle.named_parameters():\n",
    "    print(name, weight.shape)\n",
    "print(\"\\n====linear_torch info====\")\n",
    "for name, weight in linear_torch.named_parameters():\n",
    "    print(name, weight.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "<paddle.vision.datasets.cifar.Cifar10 object at 0x7f0e344decc0>\n",
      "Dataset CIFAR10\n",
      "    Number of datapoints: 50000\n",
      "    Root location: ./data\n",
      "    Split: Train\n"
     ]
    }
   ],
   "source": [
    "# 下载到默认路径\n",
    "dataset_paddle = paddle.vision.datasets.Cifar10(\n",
    "    mode=\"train\",\n",
    "    download=True)\n",
    "# 下载到data目录\n",
    "dataset_torch = torchvision.datasets.CIFAR10(\n",
    "    root=\"./data\",\n",
    "    train=True,\n",
    "    download=True)\n",
    "print(dataset_paddle)\n",
    "print(dataset_torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "paddle length:  50000\n",
      "torch length:  50000\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAC5CAYAAAAxiWT3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsiklEQVR4nO2da4xd13Xf/+uccx9z5/0kKZIyRcmWrYctO7StNEGah50G/tAkQFHYRQO3MKCgaIAEDYooaZGmRT8kReL0Q4AECmLYAdK4ae3AbmA3VRwHgdvaFi1bsiRKIik+h+S87517577P3f0wVw73/i9qLmeGd+a46wcQ5FncZ599zllnz5n9P2stcc7BMAzDyB7RQQ/AMAzD2B02gRuGYWQUm8ANwzAyik3ghmEYGcUmcMMwjIxiE7hhGEZG2dMELiI/JSKvicgFEXl6vwZlGAeN+baRBWS334GLSAzgdQAfBnAdwHMAPuace2X/hmcYw8d828gKyR72/QCAC865NwBARD4L4KcB3NHJJyYm3ZGFo3s45D4gB3v4PbOnuKudd77nYV3KAfRjBlblRSO0rKwso1rd3I87fNe+HSfi4pz/C61zPWoniP02Sl8ibI0T/mV5dLTgbY+U8tSmp4xBI4q4f9fz99Xe9bqdlGzNVofHkfo7l0oFajMxMUq2JInJpg0kDcaaJMrUpoxfFG8Z5J3WDeplA/it3pVvvbm4jvL6FjXdywR+HMC127avA/jgW+1wZOEoPvk7v7+HQ/4dol159XHwbfp+2WEvkbOD7Dt4/4pj7jznqrZejyeZXi99y20ASFPf9m9//V8r49wVd+3bcS7C0QeLnq3VanM7GfO2nXIdkxyf6+QMT24/8MHT3vbj772f2rTTBg9WeQZKRZ5QW82mv93i+7S2VCPbhdcXyVbfannbTzzxMLX50IfeT7aF+XGypSmPo1qte9szs1PUJlZ8LxL+waV0j15wydIBp5HQRwF+xrSuJPZ/cP3zf/if1f7vuYgpIk+JyFkROVvZLN/rwxnG0Ljdt8M3TMMYBnuZwBcBnLxt+0Tf5uGce8Y5d8Y5d2ZyYmoPhzOMoXHXvh3F2f7Nzsgme1lCeQ7A20XkAWw790cB/JO33EN4rU37lX2gZQ6ljQywhLLf3MslmV1fmz3uO0hfg7TTl1CU+6bdSwmXvnY+/j7eibv2bRGHKPgdPZfj96N6zf9VXzuvOGFjPuHlGPT85ZFbt65Sk0IpR7YjR4+RrbzCSyGdVtdvU+E2ly4t8X5dPu/3f/BRb/vM+x+kNuMzPB21enWy9ZT1i3bXX6Jp93jpqJhjjSBJ+PrEEY8jXNWPFeeOleUYnZ2fk5Ao0r171xO4c64rIr8A4C8BxAA+5Zx7ebf9GcZhwXzbyAp7eQOHc+5LAL60T2MxjEOD+baRBSwS0zAMI6PYBG4YhpFR9rSEcrc4p38XGbJ7YVAT7nbZlYI2rvAb5kEDA3YbQKDt1xv02+1BAhR2KVjqbQbqCr0eNwyvq/atOImmgx3unhBFEUZG/G+pG65J7eIkFDo5UGVsgr/JTpVXrZur6952vcci3cm3sWBZrXGgzcULK2RbC/rvgoXBXIGnkNMPnCTbfQ9MettxicewVuMxRODrE/f4+tRqW962ExZ9R4slshULbJNikWydIKAor4yroHhgrAUiBe/NWuAWo09k9gZuGIaRUWwCNwzDyCg2gRuGYWSUoa6Bq6jZfHbeTVs30tZlQ5MeNKIEBrR5DU2zjY6GOSoGzdGijcPf7ip6QUcZQ6rkCdHWn8eCsQ4a7LPb/CgDXwnlXg4S8HXYCOM4tOjMfN5vlMvxIzg+yfk/Kq1Nsl2/Wfa2N2u8Bj5amiHbVoXXmi+cv0G2Xur72oMPH+H+x/kc52Z53Xdq0g+Y0TSNlJfFMZLj9e5ECeQZCR+La2XuLMeaRDOukK10ZI53nfKfnV7C4+8ogTw9TZcLt7VnLgz2uYPCY2/ghmEYGcUmcMMwjIxiE7hhGEZGsQncMAwjowxZxHSczFxZwA8TyW+U16lNqlQaWZhfIFs+54snmhjWTbtku7XMWdZGR/ij/1DEHFRs00ScEK2vlbVVstW2tsg2PTlFttGSP34tqGov2RXDM3KKuLq+vka2SoWFpJnpWW97dGyM2hwmRAT5QJDUiuG4NPR/vsetNottm2XOyjc97gecjIxyAErl6gbZ0rZSMafKx3z8PX7RhXc+dJz737hFtjwXjgHWfV9wPX53zOdZsIy0POs1zoo4WvNVzEKV9ytG/AHAtVUef/oQt5t59AFvu6EU3VBKZyBSP5rwzz1WqiEhyD7olGA3wN7ADcMwMotN4IZhGBnFJnDDMIyMsqc1cBG5DKAKIAXQdc6d2Y9BGcZBY75tZIH9EDF/zDnHypqKQII4pHaHBYMLly5621cXuVSUtqT/nkcfJ9vxY/d521HMv3SUN1lEu3TlCtnefprLQA0iympiZBhpCADLq36EXKejiE2KMtZssnwyeb9SoTyI4lxaYqE2SdglFubnuV2eS1Gtrix721euXKY2G4qIWa+zQDc16UcQPv4Y39tSIMqGvrUPDO7bwpGXExNcSb5Z9yMqm80WtZkCX9u3zU2QbXLEFy3fcR/f86TLYtvN9WtkOzrCx5yPfV/YPM+CX2ODBfSusBj5+gX/mR5b4Gtz8sR9ZIsa7BuTirCZD9I1Njv8fKkiY8TRqy3lGVve8ueIXp6jTRPhZydWnqdBIjElmKe05x6wJRTDMIzMstcJ3AH4XyLyLRF5aj8GZBiHBPNt49Cz1yWUH3bOLYrIAoBnReRV59zf3t6g7/xPAcC88p22YRxS7sq3cwX7ZdYYPnvyOufcYv/vZQB/DuADSptnnHNnnHNnJiam9nI4wxgad+vbSd4mcGP47PoNXERGAUTOuWr/3z8J4D+81T6ddhuLNxY929VFFlRW1n0xr91hoUdLVnoxED8BYCuIUhwd42jKjQ0W1lbXl8mWOhaEwqg5LWBqanKSbNOKbS0Q+FbWeFxFJf3oliIC1jY5/WgrEIxvrfA5jo1xKtMexVgCzTbfk9VAhL1ylYXgVpujXjV5ZiQouaWVjbsHouV2v7vw7UgE+YIviDU3+Bq5Lf/8m5ss4o8fZ2Htve/gKMjZoETY0hL7QbnK/U8lLIhKgX17vucLm80VFvvTCgvtXeVDgc2Wv68oPjs5PUW26Rxfw6jHxyzXfVu9wX6WKuKwK4yQLWnw+KurVW+7EXP/hRLPLVGe72Ux8p/hvFJWT3KDpVPeyxLKEQB/3ldQEwD/xTn3P/fQn2EcFsy3jUyw6wncOfcGgPfs41gM41Bgvm1kBVu4MwzDyChDzUZYbzbwwqsvejYtCCUO1tDCLG8AECnli7S17PpWmLlMKYWUKmuVSrutKmd2e+Oiv2+1zn2NjnCWuIkJzq63UfHXrbeUII9IyV4HJSjoVUUPqNeCdUhlCbmyyef4khLo1O3wGmBpxA/gKBX5vKMer8n2FOHg/qP+mm+3y8dbWvXX8DtdpSbXkIijCJNFfz11ZorXNt824WsfnSavyxby7O/dCq9RJ8H5Ll+6yeNKWNO4/z4uGVZrcUDOtYv+9T1931FqMzXOz0mtViXb/Jx/zHe/jwOzThxhXaitBB1Vqjxn1FLfNlrite20x9f1ylXWgZTHDmnFDzxajTl742qbx1VtsU8ujPvnOTfNmsTojD8/tBXtCLA3cMMwjMxiE7hhGEZGsQncMAwjo9gEbhiGkVGGKmKKALlAhJteYEFlc9MXQapKCaWi8oF8TgkgSFNfROgIiy5akMisUsJrVMk21ur6/RUSFq56PRYgNmscaJMPMvwdWzhCbQoFFr2qSkmySpVtrUAwHlf6QsTXolTka50q5xmWuRsf4ax0RyZYVJOeoqb2/LFeuvoaNVla8QOH6g0W4oZFXgQng0CeyRlFfE/9a95rTVOb5Q0WyK6cv062whFf/Do1N0NtjisZCidyfN8vXlHK33X9e/DCtcvUZoZjV/DwDIuRR+/zfXlxicXDb77yBtkqy5wMcvVWmWzttv/sH1eyN+aL7LPnz10i2+nT7Edx2b+2ix0ORHrl1g2y3SzzWGdH/eszOc4XMT/u+87GBgvDgL2BG4ZhZBabwA3DMDKKTeCGYRgZxSZwwzCMjDJUETOCIB/5AspokSOm6lVfIBDl58z8LAs2Y0Ul02AgIlRbLD5spRwt1djidlq2sbDUUT7HYy2NcfmovCLCSs/vq6OMtdtkkXEkZWE2rwhVruWLOKkSAdnqKWFoSvBnT8kI1+35Qli1zuPS8gdq97e67kdsbiiib6Phi2xaROewSADMB2fnahx1Wt8s+/tFfL2P5DiC9fgRFuWOTfvtikUWjRMly+B6hT8KcI5tR6d9v13f5DZjwsLgA0qk5+xRvyzftatlapOfP0228bETZLteeYFsL59/3dt+4QKXf1s4ygL68bkpso1O8/hXgsyS62W+Fs0a+3u3yc95OYiqbNT4OS/W/XvZ6VhJNcMwjO8rbAI3DMPIKDaBG4ZhZJQdJ3AR+ZSILIvIS7fZZkTkWRE53/+boxEM45Bjvm1knUFEzE8D+D0Af3yb7WkAX3HO/aaIPN3f/pUde3IO0vVFlY6ygJ8EeU7HlTJoI0qq0jjhn0dx4vc1W2BBcaLL4k+1wtFYqSIWJkHqz46SPrJBKW0BdPiYcSDClRt8bTopi1LFPAuWCyWOJD05M+uPq8NjrSpl40JxEgCuVZbIlsSBoKX0tVTmCNFGSyn7FZR2G1VKX+Vj/9rHSlrdHfg09sm380mC+wNhfWuTzz+d9KMsj9/H4uTcBPtGThGNcwXftxtNjuA8d52jA6+XOTL4yQ8+SLbZov+sVJdmqY2kfO+6XRaTXSDMrqzx87WlROm2u9w/xpUI4nFfLFxZ5v02lJJwpWmeDzbbyjMcRDH3Wiw+lxIWLEdLPMXGzn9OEsVvu51A7nd6+cAdPb5fiXs9MP80gM/0//0ZAD+zUz+Gcdgw3zayzm7XwI84597MHn8L2zUEDeP7AfNtIzPsWcR02+WS7/gBrog8JSJnReRsXVkSMIzDyt34dq1xcNWAjP9/2e0EviQixwCg/zenFuvjnHvGOXfGOXemNKKkLjOMw8WufHtsRFmXNYx7zG4jMb8I4OMAfrP/9xcG2SmOI0wHtSBT5UfIiPgPw7QSYZl2WLBpdPhlqdn2xYZUSR07qQiivVCQA1BRBMRSUGevowidzTqPda3D6SHDcpe5ERZFxko81lRJV1vt8TjCdLsNRZwsFrj/epvHr+yKcBRhKl8AaCu1LbvKNeu0fREqp9SJrAVCklPu7S7YlW+LCIoFX3yUMR7zaM5PJbowxeLUPOvPGFN8tAv/+jYUP1upsB+/scbtZuf4oNNJUAfySpnarK6EEgJQavA5NYq+WLuhRDIur3Nq165SQ/XG9atkqwYiYxrxs6P52QWljujRERbMx4P0rqPKz+uVLfZ3LX4yTvydm20+x1bdt6XK3AMM9hnhnwL4vwAeFpHrIvIJbDv3h0XkPIAP9bcNI1OYbxtZZ8c3cOfcx+7wXz+xz2MxjKFivm1kHYvENAzDyCjDzUaY5DC64Gcl2yhzljkXZDiTHK9njRR5vTPWglCC9VvlG320UyXDn5LNL5/w5ZJg3TXJ85pjoqzHbTheh5TI7398nIMMOsoJaEn4usrqmwsyJ4ZryABQKHAQydQ4r48eO7bAYwvW3SvrvD4aKfeoqARg9YK18rJSVq8WrB2myrr/0BCgF5zGkRN83SYK/j1OFH0Bwr7XU7SJfOJfo6jD12hKWas9qpQL3Fplv/puUMbutStc3uzEcfaD2VnWrDaD+xcrM0+8xdeiq/iocppI277f5pSPhwT8bM6PcyDVSSXTaS7na1atFj/Tqx3243KXr2sazDftljJvBWved9J37A3cMAwjo9gEbhiGkVFsAjcMw8goNoEbhmFklKGKmCJAHATIqIEdQTazVMnElkSKYNVlEaQTZD8cUQJhag1tDGwrRByg0O747cZGuf+5mSmy5ascyIPIvzYLM5z97ebKGtlaTQ4EGFFK1eUSv//6VoPaRAkLPaJkSyspZb/yI76Q1NpSMs7V2JZXhN9mIP5okeqDCj3DYKSYx2MP++W/SnwLEAdjbNeUrJSJUrovz/egXvUzQm7cYp/arLFAtrrOwuBf/Z/XydZM/Hv8zr/3Y9Tm0Ye45JlbuUC2i0HJs1gR2SMl0GZ1mT9y2Nhgv+20/X21uJeSsPHxB1mEPcaPHTbLQTZC5WOCH3js3WT7QUUQ3az6z8CmkqFzY9O3ffUal5ED7A3cMAwjs9gEbhiGkVFsAjcMw8goNoEbhmFklKGKmGkvxUbNX5zvKWKki3yhp6dEVVUbLGSsr61wX0H3pRILfiUl+nBinMs79ZQUfKGIudlkcWOryyJFnONLXwwiPTfWN6hNqoirPSUCsVrlcLWw5JwoouzizVtkm5jkaDUtO1pa98fWbPI9CoVHACiAx+GCc1KGeqjePnK5BPcd86OM0WVRsRNE3cXjLAZvbLKPOuVaNlZ9v1q8xfn2r9XYz86+fo37n2Ax75/9/D/1tj/w/jPUprzIgmVljX1oecW/FjeX+ZmoOx6rUzJaFnM8H0SBOKxVYjs+zWL5/Qt8XUcnOHz11lKQybPNY/37j7yLbHPz82QbhLB04kvPcwZG4HA9A4ZhGMZdYBO4YRhGRrEJ3DAMI6MMUtDhUyKyLCIv3Wb7DRFZFJHv9P985N4O0zD2H/NtI+sMImJ+GsDvAfjjwP67zrnfvpuDxVGM6ZKfyjJWUrl2Y18MKxRZZMznp8m2tsZRiu0gYkrAUWg5pXyaK/C4Ckp0YzOI9Gy1WWRsNjhCdH6OI7Qk+Hm6uMTlnuotVmcmlPSgBaUE2VbFjwCrK8JYXRmrJt5OTPAxQ+E0ViI4Z6cmyVZQ0gVLKFoqrxq9II9uokSR7sCnsU++LXEe0eT9ni1OlbynQVrVVoX9ZXmV0/Cur3FpzurKorfthK/j8hbfu8Uy3/f77+Pww5On3+ltT88dozY5Jfp5a/k62bqpL9amTU5NOz7GYvzjT5wk29T4FNkqa75IWl5iIXVuhv3jvllO2Vyus3D62g1fIJ4+dpr7n+NrGKabBoBiUB4vjE7f3jHauQ0GeAN3zv0tAPYow8g45ttG1tnLGvgviMiL/V9D+XW4j4g8JSJnReRsTUnKbxiHkLv27TWlSK9h3Gt2O4H/PoAHATwB4CaA37lTQ+fcM865M865M2PKr/qGccjYlW/PTplvG8NnV4E8zrnvpUETkT8E8BeD7CcCxLlgLUcJ0AgzCMZdXhsbH1VKf41x8E0z76/RdZQSRw2lrJVTMqPlC0qARbAOm1dKsY2P8tp5GLQDAM2Wv842vzBHbSqbHByiBcwoS9mIg2yHYyW+hmPKWCPhn/Nph6/j5Li/vj09xev8jQZnI+x2lOsaOEa3qwQOhfdyH7IR7ta343wJ0297wh9fjdetX736nLe9ePkGtbl8ibWPpeUlsm0FQXFRkbMYXi7zGnsn5uChWPHHqbHAF1pKYNIWB5ut3OJz6nZ87enoHP/Ae+zxo2R76BQHwqzf5FWvuvg+FB/nZ6fZYN8+9wqvlb94+RLZMOGvb//4B5+gJvNzPFYt4C3MdNpus/9L8Mzta0k1EbldzfhZAC/dqa1hZAnzbSNL7PgGLiJ/CuBHAcyJyHUA/w7Aj4rIEwAcgMsAfv7eDdEw7g3m20bW2XECd859TDH/0T0Yi2EMFfNtI+tYJKZhGEZGGWo2Quc4y5wm+oUBIAUlWyDA4kAcsa0XZDMTp+wnPIZ8gS+NFifSC0RY11PKsynBPZtNzhzXC4SKOOafr8VQBAawUea+Rgoc1FEs+SJXqgTodDo8Vg1dnPHFGKUSHlotDqTqpcoxScRhtTuOw3ukKOLDQiJIzg8KWbzOgWXXb/ii39XrnEGzlfJ5NFK+70tV/x7UK/wp41qH92s79qtukwXKK+d8wbWcY2G/u8njr65ytsOc+OL1qdMscJ9Y4Of8xtVXyTY/zWLhTBBk9PVzPIbnnrtCto0NdtLR+SNk+4mP/KS3/eSTH6A2xRyPv9nkrwnqdf95jZSAt1DEpMC2N/fVzYZhGMZhxyZwwzCMjGITuGEYRkaxCdwwDCOjDFXEBDhYrtngKKQwg2CrpUT9TXDUpQsVRQBp1xfNXFhjDYBoUZF1jm7kPaEUe1PaaI2UGmFh5jKn9K4Kfsp5N5TSbp3UF6rCKFKAo8QATSzUxed6vez3pUSzaueUKhkpJVBttPuWBtGZWmm5YdHrtFG75QtnS1e5DNaVy76QdmuJRUDEfG3Xtlig3Oz6162muEalwaKx5rT5hK/dyuJ5f1ij/L5X6LFI122UyXbqlB8Z+di7OMvgxjpHro5NcgbEmmOB/m/++kVv+7nXN6lNvcHnePo0R3/+zM/+FNne/diT3nYuYsEyFPEB9mMAyOf98WsiZjhnhKLm9/ZVrYZhGMahxyZwwzCMjGITuGEYRkaxCdwwDCOjDDkS06Hb8YWndouVl3rTF17SKosPy0scObaxrggqzu9LlOjG1Ckhg1qkoSLKdQLRL8kpl1QRjRJFlMiF5dISFkA6SoTo6AiXhYJyns3guhbzLAaFZcq2bXyPcoqwGQZCFgocBdhRUgP3FIFyLIgaTZSSUmGptzuVnRoG3VYDa29817OVb3E04Nqyn750Q4merCnpRastFtXD9LGdGvtxXUnfq0UjF8c4Fe1DDz/kbZc6nDp27QZHN7790XeSbW7aF/3KSyzwrm+wH9xc5+f8W6/xMS/d8lPrdlI+n9Iop7D9oR/5YbI99u5HyFYIPnTQBPN6gyOitYc/DeaRMDod4PKNmhgK2Bu4YRhGZrEJ3DAMI6PsOIGLyEkR+aqIvCIiL4vIL/btMyLyrIic7/99x9qBhnEYMd82ss4gb+BdAL/snHsEwJMA/qWIPALgaQBfcc69HcBX+tuGkSXMt41MM0hBh5vYLu4K51xVRM4BOA7gp7FdzQQAPgPgbwD8ygD9+dtKfOPami+WXLzIYpBzSrrXiEWKiUlfDJiaVkRMLa2qElUVR3zMYhBROVHguns5JRRTq3E3FhwyX1LEyeOcSlMU8S5R0s52gjqWpSJHk9WVyD0twDGviLUSHDKv1v3k6xqKqwAwPTnhbY+F9RnBQtKXvqylHb4z++nbrUYDF196wbNdUPw2F6Rk7Tm+HltK1GWiCMLFMd8/blU4fa2GJrRPTnFk89vu99OqxspHAtOFE2RrR+y3X/+mn5r2tVdeoTbrGyyW39pg36h0WNBrIRD9wILuQw+eJtt9x06RbW2Nr/9kcHmSiKNlwxqlgF7LNUyXnSgfBEggfmqRyMBdroGLyCkA7wXwDQBH+g8AANwCwEl0DSMjmG8bWWTgCVxExgB8DsAvOee8RANu+3VSTQsiIk+JyFkROau9WRjGQbMfvl2pKzlHDOMeM9AELiI5bDv4nzjnPt83L71Zwbv/N2eiAeCce8Y5d8Y5d2ZU+Q7TMA6S/fLtydLdLd8Yxn4wSFV6wXah13POuU/e9l9fBPBxAL/Z//sLA/SFOKhLtnBkjtqNjvof4S8v8dre2gYHKMzN8kM0FpQza3EsAqaO8/pfaZRLPoVrVwDQa/j995q8punq/IF/R8lG2BC/L1ECeWanJrl/LbtcntdMwzVvLThAC+TRXj9jLbAgOCe1f2WwqVLGLVxjT5T19FC7iKK7C+TZT99udzpYXFzybNdurlK70XH//KOEr0eS8Prq3DTf90rH37enBJ8VikWyxYrm8553cfDNaOT3f/3KZWpzbfEm2b57ZZ1sz7/mBzBtNniNuqtk2mxrgV/qtOXf+67ie2ubHBT0jW+9QLaZKSUIqOjfk6lJfhlNlOA5bX07F2hWxTzPW/ngWU2VLKHAYJGYPwTg5wB8V0S+07f9Grad+89E5BMArgD4xwP0ZRiHCfNtI9MM8hXK13DnarE/sb/DMYzhYb5tZB2LxDQMw8goNoEbhmFklKFnIwyDL1JFeBkb94WX97//cWrz/POv8gFySkmjIDClXFGEhrEpsj3+6MNkm56ZINvmhv/x/osvfJvaNJUfkz1FxETeb5gqd2dKuV7iuK9cwsLIyIiffVATP7XSTVq2wEgJZKAyaKIEHygml+NzSoNj9hQp1d0hQ9vBIOgG16SjXKONINijNMFieVLkG3/0CEfzd1d8UW60xJ/pbm1wabEfeOK9ZHvnA+8g21ee/bq3/cbLL1Ob1y9cIttqnW/yVur7Y1d5d1T0SijV9lRxPHV+f3Ul2Ofl89fJduk6l7Qr5vm+FQJR/aEHuSTczAx/DDE+xkFNc7NT3naifByRCz5gaLWVjKmwN3DDMIzMYhO4YRhGRrEJ3DAMI6PYBG4YhpFRhipiwjl0g4iitrI4v9X0Ra3RcY6MOn6MIzif/84FsiU9X7hLYo5MW7rO4s/X/5bFmR/78PvI9q7HHvW2T9zP4kZXUWec8vlx6vxrk8+xmDI5zkIqFFEnSTgqMZf4P6+1cWnZCMdKnAmw3eLo0jg4JynyfUuVaLuu4gPNoITY5maZ2mxu+mG1WobHYdHpOaxs+SJ6qgjJ+XjK226liqhe4P2UpHZUmqurlGJDT8nKd4p99NwLL5Ltq3/91/7xlEjXrRY/T80uHzP8WEErI6aK3kokoxaJubnl+221xb4QJ1xCsK6MNY543zCBY7nFgqhAy2DKfZ04cdTbnp9nIXth2hdEmyZiGoZhfH9hE7hhGEZGsQncMAwjo9gEbhiGkVGGKmKmvR5qNV8wrCupVqs135ZTUqMmeRbDxsdY/GnU/X1n5zla6vjxBbIt32yQ7cv/4xtku/9BP0LuQ//gSWozO6vlQedzorJJwgKIKOJnV0kPqlVg6gTNqlU+x6//b06vOT3HIkv36hWyPTThi67jjz1CbeI5Fp81cWks59/LSBEEW+29pZPdT5rtDl65csOzVZSUqScXZr3tvCL0VhucKrlc5YjKa7d8Ebe8yW1KIyxAn3uZBctmhVPfVlr++Gvg+9Ros+/1ulx6LQkE80iJ+I2VUm9Upw9AW8ms2gz67zhu1FFE+0KRIyUTpdQgAh+tp8pz2FHKNSrXYv1V/9lJzl+mNmMl/1mqVNknAHsDNwzDyCw2gRuGYWSUHSdwETkpIl8VkVdE5GUR+cW+/TdEZFFEvtP/85F7P1zD2D/Mt42sM8gaeBfALzvnnheRcQDfEpFn+//3u8653753wzOMe4r5tpFpBqnIcxPAzf6/qyJyDsDx3RwsTbvYrPj18rT0pTNTQRRSk4WAmxtlso1Nc19Rzo/QWqq8QW2ur71GtvoWizOtFgsj337VFxpvrrC49773sZg3Mc7iyUQQZZlTxL1Iibok8RNAquThrNb99KPlTY5AXd24RbZLVy6SLdfiiM1m4t+nYoVrI+ZP3U+2OGahKg7qQnZTPsdwvzBV8U7sp2+3Ol1cDGpgdoUfr0rZF+inlVqXiPnerS5z2tNaKxAGcxwVmSqi97XFG2SLRalHGYy/0eF7rqX5jZVUybkgRWssHGXcU9YDUmWRYH2TBb160x9/ocT9J0pkc6pEqrba/GFFt+f3n8uxOCzK8yrK9ekFfXWUMNv1sv8shVG3b3JXa+AicgrAewG8+TnGL4jIiyLyKRHhhMWGkRHMt40sMvAELiJjAD4H4Jecc5sAfh/AgwCewPZbzO/cYb+nROSsiJxtNPizNcM4aPbDt7XfEAzjXjPQBC4iOWw7+J845z4PAM65Jedc6rZ/f/9DAB/Q9nXOPeOcO+OcOzOifJNqGAfJfvl2EtsHXcbw2XENXLbrF/0RgHPOuU/eZj/WX0MEgJ8F8NIAfSEOsovFiuOHH/TPjnOgTaHAARC5vJK5LAhyWV/boDYblQrZtrZ4nW1LydTXDaIKbty8Rm02v8bHnJvh38qnpqa87bTH62dahkKtsJi2Lh6un4fXZruRktEu5t+clCR0eLnrjze/weddavL6otPWPoM1fC0LYxKsgbeVdfm3Yj99O+05lFvB9VTWgsMh1pq8th3leb+OUiIsDTJrxopvtNp8TRoN1pS0tfLU+b4QPrsAMKpkToyE14JzQZCV9vtKpMwF1Qb7Y6PF69bdIKtjpJR100r3RQlfs0gpcRbG1PUUPUB7EkWLqAscPlLeo3ukDeovCIN8hfJDAH4OwHdF5Dt9268B+JiIPAHAAbgM4OcH6MswDhPm20amGeQrlK9Bf8n70v4PxzCGh/m2kXVs4c4wDCOj2ARuGIaRUYZbUg0CF2ThaysZwpZvcDBJSCk/mPgwXvIDZhJFWDp54hjZxsY4a2GtyoEv7a4vZmgiqVM+5q9VymQ7f94PmFkvc5sZRfwsKuKVJg6fPO7HqExMcJbENGVxpqIE5ESREmQUHDLRgo4SJYubIvSslX1heXJMyegYdqUthgyR8Io7JUgNwb1q9/h6t+tKwIxTBMpg326NBeLIKUK1Io67mO9nHNzQEU2Qi7j/JOFpJY798WsBNHVFnNxSroUWMJMEImmsRAWJauMgsrTD59kNREvt1mrzj/bRQRKIwWmHg6h6AzqzvYEbhmFkFJvADcMwMopN4IZhGBnFJnDDMIyMMtySammKSs0XpzodRcxo+pF/kWPRJerw0J0iIvQ6fl8dRVBsbnKkYW2Ly1P1UhYbQqmhNKKIFoqok8/xWJPEt81MccbC6UnOXpfPs6ijCadJkOWup5SdUi4hjgZlwABAtIaB9pNTSsJpBxAlFHNi3I+0rVTWuKtAJO12lVpbQyKOIkyV/FQRlRaffyu4L1qGy7ZSMyxVombTQJBXMwMqYw2jIgFA0Zshwf1Ts73EvGOcU8oFBrdYK8VWU6IuU8d9RbESjRyck+Z6OS3qUsmEqT07lOlSDX9mU1sRKNPgOey0lejn8N4qcyBgb+CGYRiZxSZwwzCMjGITuGEYRkaxCdwwDCOjDFXEjAQYC8Q7KXBe0tlAwNJ+yqjpl5XwKBdEnaWauqEIBE4pzxUpYmQUqBma1qAJEPkRPu/JET81p8gctZEBI7REiTgN9+0pkXU5rX9NJFX2lSQQkpSumkp607TDIk4+0JZ6Shsqx3cHoWcYOADNQKnb3OLI3WYgWroui/iJch6x4tvhPdYk3EgUGVO5TFrprzBdr1YqTZQUsz2lXVjMpd5kP+h0NfFQEzGVcwrEcdGiTZVn3ynSrFOuf+jLmoiv2rSHICBX4Odr0H7sDdwwDCOj2ARuGIaRUXacwEWkKCLfFJEXRORlEfn3ffsDIvINEbkgIv9VRCnDYRiHGPNtI+sMsgbeAvDjzrlav37g10TkywD+FYDfdc59VkT+AMAnsF0M9s44BwRrmU5d2/FtqbI+p+2nr4D67bQlcG3ZV19x4p3DsWmnE66T3+kI4TqkE2V9TtkvXOffHsfO10fUtX8le50yfG39PFxvFWX9NVL6j7T1yrCNlgkv1Cnufg1833y7001xa7Xs2bQsc6OBeDM7y1kWE2UtNXV8E9Zrftm/tnL+WhBZpIT3KN1ToFSiBMKI0n+1zoFxmzVfD9Cy7aU9zWe1gDEtK2LwnKvBOOx7mk1bA4+DdfdequkUSlCTFigU9K/do7DNrtfA3TZvXv1c/48D8OMA/nvf/hkAP7NTX4ZxmDDfNrLOoFXp437NwGUAzwK4CKDs3Pdisa8DOH6H3Q3j0GK+bWSZgSZw51zqnHsCwAkAHwDwzkEPICJPichZETnbaHI1bMM4SPbLt++Uq8Iw7iV39RWKc64M4KsAfhDAlIi8uXhzAsDiHfZ5xjl3xjl3ZqTI3z4bxmFgr74drhcbxjDYUcQUkXkAHedcWURGAHwYwG9h29n/EYDPAvg4gC/s1FfP9dBo+wJHmrIgEUf+sLqOgzgiLR5H1Qp9o5aJLVKUTS1TXyJ8ubpdXwTpKsJjo83BGnmlJFkuCIpwUARF7bw1fU85pzC+oqsIOFoAUKqIRqJkiQv31IJ9tMCJruID4WXUSliFAlRPKx/2Fuynb+dzMU4tTHm2EwucOXKq6AuBJSUmJVeaINu11SrZXrlw2R9DibNXLswfJVuk3LvFJS5jGGYFVbRmbFR4XNWtLbKlwc6ivDsquiAi5XlFpGQFDRxGE9k1cVIrg6YT7KuMyynZMLVjhlCmQ2W/O/UzyFcoxwB8RrY/KYgA/Jlz7i9E5BUAnxWR/wjg2wD+aIC+DOMwYb5tZJodJ3Dn3IsA3qvY38D2mqFhZBLzbSPrWCSmYRhGRrEJ3DAMI6PIIIvs+3YwkRUAVwDMAVgd2oH3nyyPP8tjB956/G9zzs0PczBvYr59KMjy2IFd+PZQJ/DvHVTkrHPuzNAPvE9kefxZHjtw+Md/2Me3E1kef5bHDuxu/LaEYhiGkVFsAjcMw8goBzWBP3NAx90vsjz+LI8dOPzjP+zj24ksjz/LYwd2Mf4DWQM3DMMw9o4toRiGYWSUoU/gIvJTIvJav9rJ08M+/t0iIp8SkWUReek224yIPCsi5/t/Tx/kGO+EiJwUka+KyCv9ijO/2Lcf+vFnrVqO+fXwyLJfA/vs2865of0BEGM73/JpAHkALwB4ZJhj2MWYfwTA+wC8dJvtPwF4uv/vpwH81kGP8w5jPwbgff1/jwN4HcAjWRg/tnNjjfX/nQPwDQBPAvgzAB/t2/8AwL84BGM1vx7u2DPr1/2x7ZtvD3vgPwjgL2/b/lUAv3rQF3SAcZ8KHP01AMduc6bXDnqMA57HF7CdcS9T4wdQAvA8gA9iO9Ah0fzpAMdnfn2w55FJv+6Pc0++PewllOMArt22ndVqJ0ecczf7/74F4MhBDmYQROQUthM3fQMZGX+GquWYXx8QWfRrYP9820TMPeK2f1we6k95RGQMwOcA/JJzbvP2/zvM43d7qJZj7I3D7BdvklW/BvbPt4c9gS8COHnb9h2rnRxylkTkGAD0/14+4PHckX619c8B+BPn3Of75syMH9hdtZwhY349ZL4f/BrYu28PewJ/DsDb+2prHsBHAXxxyGPYD76I7UotwIAVWw4CERFsFyM455z75G3/dejHLyLzIjLV//eb1XLO4e+q5QCHZ+zm10Mky34N7LNvH8Ci/UewrRpfBPBvDlpEGGC8fwrgJoAOttelPgFgFsBXAJwH8FcAZg56nHcY+w9j+9fIFwF8p//nI1kYP4B3Y7sazosAXgLw6337aQDfBHABwH8DUDjosfbHZX49vLFn1q/7498337ZITMMwjIxiIqZhGEZGsQncMAwjo9gEbhiGkVFsAjcMw8goNoEbhmFkFJvADcMwMopN4IZhGBnFJnDDMIyM8v8AEr57z2pm5wUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(\"paddle length: \", len(dataset_paddle))\n",
    "print(\"torch length: \", len(dataset_torch))\n",
    "plt.subplot(121)\n",
    "plt.imshow(dataset_paddle[0][0])\n",
    "plt.subplot(122)\n",
    "plt.imshow(dataset_paddle[1][0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[1], dtype=float32, place=CUDAPlace(1), stop_gradient=True,\n",
      "       [2.34860134])\n",
      "tensor(2.3486)\n"
     ]
    }
   ],
   "source": [
    "x = np.random.rand(32, 10).astype(np.float32)\n",
    "label = np.random.randint(0, 10, [32, ], dtype=np.int64)\n",
    "\n",
    "ce_loss_paddle = paddle.nn.CrossEntropyLoss()\n",
    "ce_loss_torch = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "loss_paddle = ce_loss_paddle(\n",
    "    paddle.to_tensor(x),\n",
    "    paddle.to_tensor(label))\n",
    "loss_torch = ce_loss_torch(\n",
    "    torch.from_numpy(x),\n",
    "    torch.from_numpy(label))\n",
    "\n",
    "print(loss_paddle)\n",
    "print(loss_torch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====class ids diff=====\n",
      "[[0], [8], [6], [2]]\n",
      "[[0], [8], [6], [2]]\n",
      "\n",
      "====socres diff=====\n",
      "[[0.9605070352554321], [0.9590041041374207], [0.9971153140068054], [0.9951398372650146]]\n",
      "[[0.9605070352554321], [0.9590041041374207], [0.9971153140068054], [0.9951398372650146]]\n"
     ]
    }
   ],
   "source": [
    "x = np.random.rand(4, 10).astype(np.float32)\n",
    "label = np.random.randint(0, 10, [4, ], dtype=np.int64)\n",
    "\n",
    "score_paddle, cls_id_paddle  = paddle.topk(paddle.to_tensor(x), k=1)\n",
    "score_torch, cls_id_torch = torch.topk(torch.from_numpy(x), k=1)\n",
    "print(\"====class ids diff=====\")\n",
    "print(cls_id_paddle.numpy().tolist())\n",
    "print(cls_id_torch.numpy().tolist())\n",
    "print(\"\\n====socres diff=====\")\n",
    "print(score_paddle.numpy().tolist())\n",
    "print(score_torch.numpy().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 1, paddle lr: 0.010000, torch lr: 0.001000\n",
      "step 2, paddle lr: 0.001000, torch lr: 0.000100\n",
      "step 3, paddle lr: 0.000100, torch lr: 0.000010\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/python3.7.0/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  \"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\", UserWarning)\n",
      "/usr/local/python3.7.0/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:370: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  \"please use `get_last_lr()`.\", UserWarning)\n"
     ]
    }
   ],
   "source": [
    "linear_paddle = paddle.nn.Linear(10, 10)\n",
    "lr_sch_paddle = paddle.optimizer.lr.StepDecay(\n",
    "    0.1,\n",
    "    step_size=1,\n",
    "    gamma=0.1)\n",
    "opt_paddle = paddle.optimizer.Momentum(\n",
    "    learning_rate=lr_sch_paddle,\n",
    "    parameters=linear_paddle.parameters(),\n",
    "    weight_decay=0.01)\n",
    "\n",
    "linear_torch = torch.nn.Linear(10, 10)\n",
    "opt_torch = torch.optim.SGD(\n",
    "    linear_torch.parameters(),\n",
    "    lr=0.1,\n",
    "    momentum=0.9,\n",
    "    weight_decay=0.1)\n",
    "lr_sch_torch = torch.optim.lr_scheduler.StepLR(\n",
    "    opt_torch,\n",
    "    step_size=1, gamma=0.1)\n",
    "\n",
    "for idx in range(1, 4):\n",
    "    lr_sch_paddle.step()\n",
    "    lr_sch_torch.step()\n",
    "    print(\"step {}, paddle lr: {:.6f}, torch lr: {:.6f}\".format(\n",
    "        idx,\n",
    "        lr_sch_paddle.get_lr(),\n",
    "        lr_sch_torch.get_lr()[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PaddleModel(\n",
      "  (conv): Conv2D(3, 12, kernel_size=[3, 3], padding=1, dilation=[0, 0], data_format=NCHW)\n",
      "  (bn): BatchNorm2D(num_features=12, momentum=0.9, epsilon=1e-05)\n",
      "  (relu): ReLU()\n",
      "  (maxpool): MaxPool2D(kernel_size=3, stride=2, padding=1)\n",
      ")\n",
      "TorchModel(\n",
      "  (conv): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (bn): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU()\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class PaddleModel(paddle.nn.Layer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = paddle.nn.Conv2D(\n",
    "            in_channels=3,\n",
    "            out_channels=12,\n",
    "            kernel_size=3,\n",
    "            padding=1,\n",
    "            dilation=0)\n",
    "        self.bn = paddle.nn.BatchNorm2D(12)\n",
    "        self.relu = paddle.nn.ReLU()\n",
    "        self.maxpool = paddle.nn.MaxPool2D(\n",
    "            kernel_size=3,\n",
    "            stride=2,\n",
    "            padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class TorchModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(\n",
    "            in_channels=3,\n",
    "            out_channels=12,\n",
    "            kernel_size=3,\n",
    "            padding=1)\n",
    "        self.bn = torch.nn.BatchNorm2d(12)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "        self.maxpool = torch.nn.MaxPool2d(\n",
    "            kernel_size=3,\n",
    "            stride=2,\n",
    "            padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv_layer(x)\n",
    "        x = self.bn_layer(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "        return x\n",
    "\n",
    "paddle_model = PaddleModel()\n",
    "torch_model = TorchModel()\n",
    "print(paddle_model)\n",
    "print(torch_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====paddle names====\n",
      "['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias', 'bn._mean', 'bn._variance']\n",
      "\n",
      "====torch names====\n",
      "['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']\n"
     ]
    }
   ],
   "source": [
    "print(\"====paddle names====\")\n",
    "print(list(paddle_model.state_dict().keys()))\n",
    "print(\"\\n====torch names====\")\n",
    "print(list(torch_model.state_dict().keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[6], dtype=float32, place=CUDAPlace(1), stop_gradient=True,\n",
      "       [1.        , 2.        , 3.        , 3.        , 3.        , 3.50000000])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/python3.7.0/lib/python3.7/site-packages/paddle/tensor/creation.py:125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  if data.dtype == np.object:\n"
     ]
    }
   ],
   "source": [
    "def clip_funny(x, minv, maxv):\n",
    "    midv = (minv + maxv) / 2.0\n",
    "    cond1 = paddle.logical_and(x > minv, x < midv)\n",
    "    cond2 = paddle.logical_and(x >= midv, x < maxv)\n",
    "    \n",
    "    y = paddle.where(cond1, paddle.ones_like(x) * minv, x)\n",
    "    y = paddle.where(cond2, paddle.ones_like(x) * maxv, y)\n",
    "    return y\n",
    "\n",
    "x = paddle.to_tensor([1, 2, 2.5, 2.6, 3, 3.5])\n",
    "y = clip_funny(x, 2, 3)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
